package cnki.tpi.kbatis.handler;

import cnki.tpi.kbatis.config.Configuration;
import cnki.tpi.kbatis.config.MappedStatement;
import cnki.tpi.kbatis.sqlsource.BoundSql;
import cnki.tpi.kbatis.sqlsource.ParameterMapping;
import cnki.tpi.kbatis.utils.DataSourceUtil;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import java.lang.reflect.Field;
import java.sql.*;
import java.util.List;


/**
 * @ClassName StatementHandler
 * @Description
 * @Author 小黄
 * @Date 2019/11/16 15:43
 * @Version 1.0
 */
@Data
@Slf4j
public class StatementHandler {
    private String sql;
    private Connection connection;
    private MappedStatement mappedStatement;
    private Configuration configuration;
    private Object param;
    private BoundSql boundSql;

    public StatementHandler(MappedStatement mappedStatement, Configuration configuration, Object param, BoundSql boundSql, Connection connection, String sql) {
        this.connection = connection;
        this.sql = sql;
        this.mappedStatement = mappedStatement;
        this.configuration = configuration;
        this.param = param;
        this.boundSql = boundSql;
    }

    public ResultSet executeQuery(String sql) throws Exception {
        PreparedStatement prepareStatement = null;
        try {
            // 创建Statement
            prepareStatement = connection.prepareStatement(sql);
            // 执行Statement
//            sql = sql.replaceAll("\\\\", "\\\\\\\\");
//            sql = sql.replaceAll("'", "\\\\'");
            return prepareStatement.executeQuery(sql);
        } catch (SQLException e) {
            e.printStackTrace();
        } finally {
            release(prepareStatement);
        }
        return null;
    }

    public Boolean execute() throws Exception {
        Statement pst = null;
        try {
            sql = parameterize(sql, mappedStatement, boundSql, param);
            // 创建Statement
            pst = connection.createStatement();
            System.out.println(sql);
            // 执行Statement
            return pst.execute(sql);
        } catch (SQLException e) {
            e.printStackTrace();
        } finally {
            release(pst);
        }
        return Boolean.FALSE;
    }

    private void parameterize(PreparedStatement prepareStatement, MappedStatement mappedStatement, BoundSql boundSql, Object param) throws Exception {
        // 先判断入参类型
        Class<?> parameterTypeClass = mappedStatement.getParameterTypeClass();
        if (parameterTypeClass == Integer.class) {
            prepareStatement.setObject(1, Integer.parseInt(String.valueOf(param)));
        } else if (parameterTypeClass == String.class) {
            prepareStatement.setObject(1, String.valueOf(param));
        } else {// 自定义对象类型
            List<ParameterMapping> parameterMappings = boundSql.getParameterMappings();
            for (int i = 0; i < parameterMappings.size(); i++) {
                // 获取#{}中的属性名称
                ParameterMapping parameterMapping = parameterMappings.get(i);
                String name = parameterMapping.getName();
                // 根据属性名称，获取入参对象中对应的属性的值
                // 要求#{}中的属性名称和入参对象中的属性名称一致
                Field field = parameterTypeClass.getDeclaredField(name);
                field.setAccessible(true);
                Object value = field.get(param);
                prepareStatement.setObject(i + 1, value);
            }
        }
    }

    /**
     * 设置sql里的参数
     *
     * @param sql
     * @param mappedStatement
     * @param boundSql
     * @param param
     * @return
     * @throws Exception
     */
    public String parameterize(String sql, MappedStatement mappedStatement, BoundSql boundSql, Object param) throws Exception {
        // 先判断入参类型

        Class<?> parameterTypeClass = mappedStatement.getParameterTypeClass();
        if (parameterTypeClass == Integer.class) {
            sql = sql.replace("?", String.valueOf(param));
        } else if (parameterTypeClass == String.class) {
            sql = sql.replace("?", String.valueOf(param));
        } else {// 自定义对象类型
            List<ParameterMapping> parameterMappings = boundSql.getParameterMappings();
            String[] sqlArr = sql.split("\\?");
            if (parameterMappings != null && parameterMappings.size() > 0) {
                if (sqlArr != null && sqlArr.length > 0 && sqlArr.length - 1 == parameterMappings.size()) {
                    StringBuilder builder = new StringBuilder();
                    for (int i = 0; i < parameterMappings.size(); i++) {
                        // 获取#{}中的属性名称
                        ParameterMapping parameterMapping = parameterMappings.get(i);
                        String name = parameterMapping.getName();
                        // 根据属性名称，获取入参对象中对应的属性的值
                        // 要求#{}中的属性名称和入参对象中的属性名称一致
                        List<Field> fields = DataSourceUtil.getFields(parameterTypeClass);
                        Field targetField = null;
                        for (Field field : fields) {
                            if (name.equalsIgnoreCase(field.getName())) {
                                targetField = field;
                                break;
                            }
                        }
                        targetField.setAccessible(true);
                        Object value = targetField.get(param);
                        String sss =null;
                        String parameters =param.toString();
                        //这里改了
                        if (value != null) {
                             sss = value.toString();
                            if (!sss.contains("BELONG_TO") && !sss.contains("标准名称 =")  && !sss.contains("标准号 =") && !sss.contains("编码=") && !sss.contains("主题 =") && !sss.contains("全文 %=")
                                    && !sss.contains("发布单位 =") && !sss.contains("起草单位 =") && !sss.contains("归口单位 =")) {



                                sss = sss.replaceAll("\\\\", "\\\\\\\\");
                                sss = sss.replaceAll("'", "\\\\'");
                            }
                        }
                        builder.append(sqlArr[i] + sss);
                    }
                    builder.append(sqlArr[sqlArr.length - 1]);
                    sql = builder.toString();
                }
            }
        }
        return sql;
    }

    private void release(Statement pstm) {
        if (pstm != null) {
            try {
                pstm.close();
            } catch (Exception e) {
                e.printStackTrace();
            }
        }
    }
}
